import gymnasium as gym
from Environment.environment import Environment
from Environment.Environments.Phyre.level_builder import PHYRETemplate
from Environment.Environments.Phyre.tasks import get_task, list_available_tasks
import numpy as np
from Box2D import b2World, b2_pi, b2Vec2, b2ContactListener
from Environment.Environments.Phyre.rendering import render_scene
import time
import copy
from Environment.Environments.Phyre.objects import PHYREObjectWrapper
from Environment.Environments.Phyre.levels import PHYRELevel
from Environment.Environments.Phyre.phyre_init_specs import generate_object_dicts
from Environment.Environments.Phyre.objects import *
from Environment.Environments.Phyre.phyre_specs import phyre_variants
import sys, os
import yaml

def convert_name(name, counters):
    """Translate a task-level object name into a numbered environment name.

    The name is classified as ``Platform`` (contains "platform"), ``Basket``
    (exactly "basket"), ``Target`` (exactly "green_ball") or ``Ball``
    (anything else). The current counter for that class is appended to the
    class name, and the counter is then incremented in place.

    Args:
        name: Object name as used by the task template.
        counters: Mutable mapping with integer entries for "Platform",
            "Basket", "Target" and "Ball"; updated by this call.

    Returns:
        The class name followed by the index it was assigned.
    """
    if "platform" in name:
        prefix = "Platform"
    elif name == "basket":
        prefix = "Basket"
    elif name == "green_ball":
        prefix = "Target"
    else:
        prefix = "Ball"

    index = counters[prefix]
    counters[prefix] = index + 1
    return prefix + str(index)

class Action:
    """Holds the most recent action so it can be read like any other object."""

    def __init__(self, num_action_objects, passive):
        # The action 0 is no-op: a passive action is the scalar 0, otherwise
        # the action stores one 2-vector per action object.
        self.name = "Action"
        self.passive = passive
        self.continuous = not passive
        if passive:
            self.attribute = 0
        else:
            self.attribute = np.zeros((num_action_objects, 2))
        self.interaction_trace = []

    def take_action(self, action):
        """Store `action` as the current attribute."""
        self.attribute = action

    def get_state(self):
        """Return the attribute as an array (wrapped in a list when discrete)."""
        if self.continuous:
            return np.array(self.attribute)
        return np.array([self.attribute])


class GoalContactListener(b2ContactListener):
    """Contact listener that marks the episode solved when target meets goal."""

    def __init__(self, env):
        super().__init__()
        self.env = env
        self.target_object = env.level.target_object
        self.goal_object = env.level.goal_object

    def BeginContact(self, contact):
        """Flag success on `env` if the contact joins the target and the goal."""
        # Bodies are identified through their userData
        data_a = contact.fixtureA.body.userData
        data_b = contact.fixtureB.body.userData

        target_then_goal = data_a == self.target_object and data_b == self.goal_object
        goal_then_target = data_a == self.goal_object and data_b == self.target_object
        if not (target_then_goal or goal_then_target):
            return

        # Player hit the target, terminate the episode
        # NOTE(review): PHYREWorld reads `self.done.attribute`; rebinding
        # `env.done` to a bool here looks inconsistent with that - confirm.
        self.env.done = True
        self.env.info["termination"] = "SUCCESS"
        self.env.success = True

    def EndContact(self, contact):
        """Contacts ending are ignored."""
        pass

def solve_level(
    level_dict,
    variant,
    max_trials=100,
    env_config_path="Environment/Environments/Phyre/",
):
    """Search for an action that solves a level by sampling random actions.

    Args:
        level_dict: Serialized level handed to ``PHYRELevel.load_dict``.
        variant: Variant key forwarded to ``PHYREWorld``.
        max_trials: Maximum number of random actions to try.
        env_config_path: Directory that contains ``config.yaml``.

    Returns:
        The first sampled action whose step reported the termination
        ``"SUCCESS"``, or ``None`` if no trial did.
    """
    # Load the simulator config from YAML; the context manager guarantees the
    # file handle is closed even if parsing fails.
    with open(os.path.join(env_config_path, "config.yaml"), "r") as config_file:
        env_config = yaml.load(config_file, Loader=yaml.FullLoader)

    level = PHYRELevel()
    level.load_dict(level_dict, with_solution=False)
    env = PHYREWorld(variant=variant, test_mode=True)
    env.level = level
    env.level.make_level(
        env.world,
        env_config["screen_size"],
        env_config["screen_size"],
        env_config["ppm"],
    )
    # NOTE(review): PHYREWorld.reset() samples and builds a fresh level, so the
    # level installed above appears to be replaced here - confirm intent.
    env.reset()

    # Take random actions until a usable solution is found
    # TODO - using a policy that is actually good would probably be better
    for _ in range(max_trials):
        action = env.action_space.sample()
        _, _, _, info = env.step(action)
        # step() only records "termination" when the episode ends, so a plain
        # index would raise KeyError on every non-terminal step.
        termination = info.get("termination")
        env.reset()
        if termination == "SUCCESS":
            return action
    return None


def sample_random_level(task, check_solvable=True, max_trials=100):
    """Sample a level from `task`, optionally retrying until one is solvable.

    Args:
        task: Task template; ``task.create_level(task)`` produces a level and
            ``task.name`` is passed to ``solve_level`` as the variant.
        check_solvable: When true, keep sampling until ``solve_level`` finds a
            solution or `max_trials` levels have been tried.
        max_trials: Upper bound on sampled levels, also forwarded to
            ``solve_level`` as its own trial budget.

    Returns:
        The sampled level. With `check_solvable`, this is the first level that
        was solved, or the last level tried if none was solved. Returns
        ``None`` when `check_solvable` is true and `max_trials` is not
        positive (previously this raised ``UnboundLocalError``).
    """
    if not check_solvable:
        return task.create_level(task)

    level = None  # stays None only if the loop body never runs
    for _ in range(max_trials):
        level = task.create_level(task)
        if solve_level(level, variant=task.name, max_trials=max_trials) is not None:
            break
    return level

class PHYREWorld(Environment):
    def __init__(
        self,
        frameskip=1,
        horizon=200,
        variant="",
        fixed_limits=False,
        renderable=False,
        test_mode=False,
    ):
        """Build the PHYRE world for one variant and reset it to a first level.

        Args:
            frameskip: Forwarded to the base ``Environment`` constructor.
            horizon: When positive, replaces the variant's ``max_steps``.
            variant: Key into ``phyre_variants`` selecting the configuration.
            fixed_limits: Forwarded to the base class and stored; unused here.
            renderable: Stored on the instance.
            test_mode: When true, the creation message is not printed.

        Raises:
            NotImplementedError: If the variant lists a task that
                ``list_available_tasks`` does not report.
        """
        super().__init__(frameskip=frameskip, variant=variant, fixed_limits=fixed_limits)

        self.name = "Phyre"
        self.fixed_limits = fixed_limits  # isn't used for now
        self.renderable = renderable
        self.world = b2World(gravity=(0, -10), doSleep=True)
        self.variant = variant
        self.test_mode = test_mode

        # The variant entry is a flat tuple; unpack it straight onto the instance
        (
            self.screen_size,
            self.world_size,
            self.fps,
            self.vel_iters,
            self.pos_iters,
            self.min_action_radius,
            self.max_action_radius,
            self.max_action_objects,
            self.min_num_null,
            self.max_num_null,
            self.force_live,
            self.task_names,
            self.max_steps,
        ) = phyre_variants[self.variant]
        # Ratio of screen size to world size, truncated to an int
        self.ppm = int(self.screen_size / self.world_size)

        # If set, actions will just be a dummy 0 and the environment decides all dynamics
        # This is only used for data collection - starting states of the action objects are preset
        self.passive = True
        self.continuous_actions = not self.passive

        # Load task list, failing on the first task that is not available
        available = list_available_tasks()
        self.tasks = []
        for task_name in self.task_names:
            if task_name in available:
                self.tasks.append(get_task(task_name))
                continue
            raise NotImplementedError(
                "Task {} not available. Available variants are {}".format(
                    task_name, available
                )
            )

        # The origin is at the center of the frame, so bounds are +/- half the world
        lower_bound = -self.world_size * 0.5
        upper_bound = self.world_size * 0.5

        # Set up action space
        num_action_objects = max(
            self.max_action_objects,
            max(len(task.action_objects) for task in self.tasks),
        )
        self.num_action_objects = num_action_objects
        if self.passive:
            self.action_space = gym.spaces.Box(
                low=np.array([0]),
                high=np.array([1]),
                dtype=np.float32,
            )
        else:
            action_space_low = np.tile(
                np.array([lower_bound, lower_bound]), (num_action_objects, 1)
            )
            action_space_high = np.tile(
                np.array([upper_bound, upper_bound]), (num_action_objects, 1)
            )
            # If there is only one action_object, then the action_object space should be just 2D
            if num_action_objects == 1:
                action_space_low = action_space_low.flatten()
                action_space_high = action_space_high.flatten()
            self.action_space = gym.spaces.Box(
                low=action_space_low,
                high=action_space_high,
                dtype=np.float32,
            )

        self.observation_space = gym.spaces.Box(
            low=np.array([lower_bound, lower_bound]),
            high=np.array([upper_bound, upper_bound]),
            dtype=np.float32,
        )

        self.pos_size = 2

        self.action = Action(num_action_objects, self.passive)
        self.extracted_state = None

        # Running values
        self.itr = 0
        self.total_itr = 0
        if horizon > 0:
            self.max_steps = horizon
        self.needs_reset = False

        # Factorized state properties
        self.all_names, self.object_names = self.generate_names()
        self.num_objects = len(self.all_names)
        (
            self.object_sizes,
            self.object_range,
            self.object_dynamics,
            self.object_range_true,
            self.object_dynamics_true,
            self.position_masks,
            self.object_instanced,
        ) = generate_object_dicts(
            num_action_objects,
            self.max_action_objects,
            self.object_names,
            self.all_names,
            self.world_size,
            self.passive,
        )

        # self.object_instanced = self.generate_instancing()  # TODO: support instancing
        # Every object type except the three bookkeeping entries is proximal
        self.object_proximal = {
            name: name not in ("Action", "Reward", "Done")
            for name in self.object_names
        }
        if not self.test_mode:
            print("Created Phyre world")
        self.reset()

    def initialize_object_dict(self):
        """Rebuild ``object_name_dict`` and ``objects`` for the current level.

        "Action", "Reward" and "Done" map to the environment's own objects.
        Every other name in ``all_names`` gets a ``PHYREObjectWrapper``; the
        wrapped body is ``None`` when the level has no body under that name.
        """
        bookkeeping = ("Action", "Reward", "Done")
        self.object_name_dict = {
            "Action": self.action,
            "Reward": self.reward,
            "Done": self.done,
        }
        # Environment name -> task-level body name for the chosen task
        env_to_task = self.task_name_dict[self.task_choice.name][0]
        for name in self.all_names:
            if name in bookkeeping:
                continue
            body_name = env_to_task.get(name, name)

            if body_name in self.level.bodies:
                body = self.level.bodies[body_name]
            else:
                body = None

            if body_name in self.level.objects:
                dynamic = self.level.objects[body_name].dynamic
            else:
                dynamic = False  # unused placeholder for objects absent from the level

            # TODO: assumes only numbers are instance identifiers
            type_name = name.strip("0123456789")
            self.object_name_dict[name] = PHYREObjectWrapper(
                name, body_name, body, type_name, dynamic
            )
        # The objects list is used for data generation
        self.objects = [self.object_name_dict[name] for name in self.all_names]
    

    def generate_names(self):
        """Assign environment-level names to the objects of every task.

        Fills ``self.task_name_dict[task.name]`` with a pair of dicts:
        index 0 maps environment name -> task object name, index 1 maps task
        object name -> environment name. A class whose only instance across
        all tasks is index 0 loses its numeric suffix ("Basket0" -> "Basket").

        Returns:
            Tuple ``(all_names, object_names)``: ``all_names`` is "Action",
            the sorted environment names, then "Reward" and "Done";
            ``object_names`` is the sorted set of those names with leading and
            trailing digits stripped.
        """
        # Fixed here rather than read from a counter dict leaked out of the
        # per-task loop, so the method also works with an empty task list.
        prefixes = ("Platform", "Target", "Ball", "Basket")

        all_names = set()
        self.task_name_dict = {}
        for task in self.tasks:
            env_to_task, task_to_env = {}, {}
            self.task_name_dict[task.name] = [env_to_task, task_to_env]
            counters = dict.fromkeys(prefixes, 0)
            # Sorted so the numbering is the same on every run
            task_object_names = sorted(task.objects.keys())
            # TODO: assumes that the environment contains exactly 1 action object
            # AND all additional action objects are red balls
            extra_names = [
                "red_ball" + str(i) for i in range(self.num_action_objects - 1)
            ]
            for obj_name in task_object_names + extra_names:
                add_name = convert_name(obj_name, counters)  # increments the counter
                env_to_task[add_name] = obj_name
                task_to_env[obj_name] = add_name
                all_names.add(add_name)

        # Drop the "0" suffix from classes that never have a second instance
        for prefix in prefixes:
            first, second = prefix + "0", prefix + "1"
            if first not in all_names or second in all_names:
                continue
            all_names.remove(first)
            all_names.add(prefix)
            for task in self.tasks:
                env_to_task, task_to_env = self.task_name_dict[task.name]
                if first in env_to_task:
                    original_name = env_to_task.pop(first)
                    env_to_task[prefix] = original_name
                    task_to_env[original_name] = prefix

        all_names = ["Action"] + sorted(all_names) + ["Reward", "Done"]
        object_names = sorted({n.strip("0123456789") for n in all_names})
        return all_names, object_names

    def reset(self, task_choice_idx=-1, valid_names=None):
        """Clear the Box2D world and build a freshly sampled level.

        Args:
            task_choice_idx: Index into ``self.tasks``; a negative value picks
                a task uniformly at random.
            valid_names: Environment-level names of the objects to keep. When
                ``None``, a random subset of the level's objects is kept and
                ``force_live`` is always kept if the level contains it.
        """
        self.itr = 0
        # Size the action by num_action_objects, as __init__ does, so that it
        # agrees with the action space in the non-passive case.
        self.action = Action(self.num_action_objects, self.passive)
        self.reward.attribute = 0.0
        self.done.attribute = False

        # Reset the world
        self.world.ClearForces()
        for body in self.world.bodies:
            self.world.DestroyBody(body)

        # Object dict follows from Box2dObjWrapper and BouncingShapes
        # Need a sample level with created bodies in order to initialize the dict
        # Refactor later
        if task_choice_idx < 0:
            self.task_choice_idx = np.random.choice(np.arange(len(self.tasks)))
        else:
            self.task_choice_idx = task_choice_idx
        self.task_choice = self.tasks[self.task_choice_idx]
        self.level = PHYRELevel(
            self.task_choice.create_level(level_name=self.task_choice.name + str(self.itr))
        )

        # initialize the action objects
        # TODO: always assumes there is already one action ball, called "red_ball", otherwise it should be max_action_objects + 1
        # randint(0, 0) raises ValueError, so skip the draw when no extra
        # action objects are allowed.
        if self.max_action_objects > 0:
            num_action_balls = np.random.randint(0, self.max_action_objects)
        else:
            num_action_balls = 0
        for i in range(num_action_balls):
            radius = np.random.rand() * (self.max_action_radius - self.min_action_radius) + self.min_action_radius
            # TODO: hardcoded xy ranges
            action_object = Ball(np.random.uniform(-4.5, 4.5), np.random.uniform(-2, 4), radius, "red", True)
            ball_name = "red_ball" + str(i)
            self.level.objects[ball_name] = action_object
            self.level.action_objects.append(ball_name)

        # apply nulling operations
        if valid_names is None:
            use_ids = list(self.level.objects.keys())
            num_level_valid = len(use_ids)
            # Action, reward and done are always valid (but not control)
            self.instance_length = num_level_valid - np.random.randint(self.min_num_null, self.max_num_null + 1)
            keep_forced = len(self.force_live) > 0 and self.force_live in use_ids
            if keep_forced:
                # Reserve one slot for the forced object and sample the rest
                self.instance_length -= 1
                use_ids.remove(self.force_live)
            names = np.random.choice(use_ids, size=self.instance_length, replace=False).tolist()
            if keep_forced:
                names.append(self.force_live)
        else:
            env_to_task = self.task_name_dict[self.task_choice.name][0]
            names = [env_to_task[n] for n in valid_names if n not in ["Action", "Reward", "Done"]]

        # Collect first, delete afterwards: the dict must not change size
        # while it is being iterated.
        remove = [name for name in self.level.objects.keys() if name not in names]
        for name in remove:
            if name in self.level.action_objects:
                self.level.action_objects.remove(name)
            del self.level.objects[name]

        self.level.make_level(self.world, self.world_size)
        self.generate_from_env()

    def detect_stationary_world(self, level, tolerance=1e-3):
        """Return True when every body that belongs to `level` is at rest.

        Args:
            level: Level whose ``objects`` mapping decides which world bodies
                are checked (matched through ``body.userData``).
            tolerance: Largest linear speed and angular speed magnitude that
                still counts as stationary.

        Returns:
            False as soon as one checked body exceeds the tolerance, else True.
        """
        for body in self.world.bodies:
            if body.userData not in level.objects:
                continue
            # Angular velocity is signed, so compare its magnitude; otherwise
            # a body spinning in the negative direction would look stationary.
            if (
                body.linearVelocity.length > tolerance
                or abs(body.angularVelocity) > tolerance
            ):
                return False
        return True

    def get_contacts(self):
        """Report which currently valid objects are touching each other.

        Returns:
            Tuple ``(contact_matrix, converted_names)``. ``contact_matrix`` is
            a square boolean array over the valid objects (bookkeeping names
            excluded), set symmetrically for every touching pair.
            ``converted_names`` maps each object's environment name to the
            environment names of the objects it touches.
        """
        bookkeeping = ("Action", "Reward", "Done")
        env_to_task, task_to_env = self.task_name_dict[self.task_choice.name]

        body_names = [env_to_task[n] for n in self.valid_names if n not in bookkeeping]
        count = len(body_names)
        contact_matrix = np.zeros((count, count)).astype(bool)
        index_of = {name: idx for idx, name in enumerate(body_names)}
        touching = {name: [] for name in body_names}

        for body_name in body_names:
            row = index_of[body_name]
            for edge in self.level.bodies[body_name].contacts:
                other_name = edge.other.userData
                # Skip bodies outside the valid set and contacts not touching
                if other_name not in body_names or not edge.contact.touching:
                    continue
                touching[body_name].append(other_name)
                col = index_of[other_name]
                contact_matrix[row, col] = True
                contact_matrix[col, row] = True

        # Translate task-level body names back to environment names
        converted_names = {
            task_to_env[name]: [task_to_env[other] for other in others]
            for name, others in touching.items()
        }
        return contact_matrix, converted_names
    
    def render(self, mode="human"):
        """Draw the current world onto a white RGB frame and return it.

        `mode` is accepted but not used by this implementation.
        """
        side = self.screen_size
        canvas = np.full((side, side, 3), 255, dtype=np.uint8)
        render_scene(self.world, self.level, canvas, self.ppm)
        return canvas

    def get_state(self, render=False):
        """Return the raw and factored state of the environment.

        Args:
            render: When true the raw state is a rendered frame; otherwise it
                is a 2x2 array of zeros.

        Returns:
            Dict with "raw_state" and "factored_state". The factored state
            holds one entry per name in ``all_names`` plus "VALID_NAMES" and
            "TASK_NAME".
        """
        # TODO - returns a 2d vector if not rendering as in BouncingShapes but this should change
        raw_state = self.render() if render else np.zeros((2, 2))

        # Seed the bookkeeping keys first so they lead the dict; the loop
        # below then replaces their values via each object's get_state().
        factored_state = {}
        factored_state["Action"] = np.array(self.action.attribute)
        factored_state["Reward"] = np.array(self.reward.attribute)
        factored_state["Done"] = np.array(self.done.attribute)
        # TODO - do these need to be rounded to 5 decimal places?
        factored_state.update(
            (name, self.object_name_dict[name].get_state()) for name in self.all_names
        )
        factored_state["VALID_NAMES"] = self.valid_binary(self.valid_names)
        factored_state["TASK_NAME"] = self.task_choice_idx
        return {"raw_state": raw_state, "factored_state": factored_state}

    def generate_from_env(self):
        """Recompute ``valid_names`` from the level and rebuild the object dict."""
        # Task-level object name -> environment name for the chosen task
        task_to_env = self.task_name_dict[self.task_choice.name][1]
        live_names = sorted(task_to_env[n] for n in self.level.objects.keys())
        self.valid_names = ["Action", "Reward", "Done"] + live_names
        # TODO - generate_from_env and initialize_object_dict seem to be doing the same thing
        self.initialize_object_dict()

    def get_itr(self):
        """Return the number of simulation steps taken since the last reset."""
        return self.itr

    def set_from_factored_state(self, factored_state, valid_names):
        """Reset to the recorded task and rebuild each body from a factored state.

        Args:
            factored_state: Mapping with a "TASK_NAME" index and one state
                entry per environment object name.
            valid_names: Environment names of the objects that should exist.
        """
        self.reset(
            valid_names=valid_names,
            task_choice_idx=int(factored_state["TASK_NAME"]),
        )
        # Task-level object name -> environment name for the chosen task
        task_to_env = self.task_name_dict[self.task_choice.name][1]
        for name in self.level.objects.keys():
            obj_name = task_to_env[name]
            ofs = factored_state[obj_name]
            object_attrs = attrs_from_state(ofs, obj_name)
            # remove the old object and replace it with a new one
            self.world.DestroyBody(self.level.bodies[name])
            self.level.objects[name] = object_attrs
            self.level.bodies[name] = create_object(self.world, name, obj_name, object_attrs)
            wrapper = self.object_name_dict[obj_name]
            wrapper.object = self.level.bodies[name]
            set_velocities_from_state(ofs, wrapper.object, obj_name.strip("0123456789"))
        self.generate_from_env()  # TODO: debug this function


    def reset_traces(self):
        """Give every entry of ``self.objects`` a fresh, empty interaction trace."""
        for wrapped in self.objects:
            wrapped.interaction_trace = []

    def step(self, action, render=False):
        """Apply `action`, advance the Box2D simulation, and report the outcome.

        In passive mode the action is replaced by the dummy value 0 and the
        world advances ``self.frameskip`` steps; otherwise it advances
        ``self.max_steps`` steps. If the episode ended while passive, the
        environment is reset before returning; the returned observation is
        taken before that reset.

        Args:
            action: Target position(s) for the action objects; ignored when
                passive.
            render: When true, render every simulated frame and sleep
                ``1 / fps`` seconds after it; the returned raw state is then a
                rendered frame.

        Returns:
            Tuple ``(obs, reward, done, info)``. ``info["termination"]`` is
            only set when the world became stationary or the step budget ran
            out.
        """
        self.info = {}
        # Set positions for all action objects
        if self.passive:
            action = 0
        elif len(self.task_choice.action_objects) == 1:
            # `self.task` is never assigned in this class; the current task is
            # `self.task_choice`.
            # NOTE(review): this branch wraps the action but does not move the
            # action body - confirm whether that is intended.
            action = [action]
        else:
            for i, obj_name in enumerate(self.level.action_objects):
                target_position = action[i]
                self.level.bodies[obj_name].position = b2Vec2(
                    float(target_position[0]), float(target_position[1])
                )
        self.action.attribute = action

        # Run the simulation for a fixed number of steps
        # if not passive, assumes that the state is ALREADY reset
        # This is so that the current state can be used as the initial state
        # and after stepping, the state is the outcome
        # TODO: action handling unclear in this case
        run_steps = self.frameskip if self.passive else self.max_steps
        self.done.attribute = False
        time_step = 1.0 / self.fps  # constant across the loop

        self.reset_traces()
        for _ in range(run_steps):
            # Step Box2D simulation
            self.world.Step(time_step, self.vel_iters, self.pos_iters)
            self.itr += 1

            # Accumulate this frame's contacts into each object's trace
            _, contact_names = self.get_contacts()
            for name in contact_names:
                self.object_name_dict[name].interaction_trace += contact_names[name]

            # Check if the world is stationary and mark the episode done if so
            # NOTE(review): the loop keeps stepping after this (no break), so a
            # later TIMEOUT can overwrite the termination - confirm intent.
            if self.detect_stationary_world(self.level):
                self.info["termination"] = "STATIONARY_WORLD"
                self.reward.attribute = 0.0
                self.done.attribute = True

            # Render the world and pace playback to the simulation rate
            if render:
                self.render()
                time.sleep(1.0 / self.fps)

            # If timeout, zero reward
            if self.itr >= self.max_steps:
                self.info["termination"] = "TIMEOUT"
                self.reward.attribute = 0.0
                self.done.attribute = True

        # Return the observation, reward, done, and info
        obs = self.get_state(render)
        if self.done.attribute:
            self.needs_reset = True
            if self.passive:
                self.reset()
        return obs, self.reward.attribute, self.done.attribute, self.info

        # TODO - add interaction trace

    def toString(self, extracted_state):
        """Format an extracted state as a printable string.

        Starts from the base class's string and appends a tab-terminated
        ``TASK_NAME:<value>`` field when the state carries a "TASK_NAME" key.
        Note this might be overridden since self.objects is not a guaranteed
        attribute.
        """
        estring = super().toString(extracted_state)
        if "TASK_NAME" not in extracted_state:
            return estring
        return estring + f"TASK_NAME:{extracted_state['TASK_NAME']!s}\t"
